"""Turn Taking evaluation library.

This library takes two input files:
the first contains the likelihoods predicted by our turn-taking model,
and the second contains the turn-taking decisions made by the AI and the human.
Both files are annotated at every 40ms chunk.

It computes our proposed evaluation metrics that perform pairwise
comparison of 2 turn-taking events using "soft" labels and
overall metrics (namely F1 and confusion matrix) using "hard" labels
(All thresholds are tuned on the in-domain validation set).
"""

import dataclasses
import enum

import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import (
    ConfusionMatrixDisplay,
    accuracy_score,
    classification_report,
    confusion_matrix,
    f1_score,
    roc_auc_score,
)


class LabelIndex(enum.Enum):
    """Position of each turn-taking label inside a likelihood array.

    ``likelihood_arr[LabelIndex.<L>.value]`` holds the model's predicted
    probability of label <L> for one chunk (see `compute_turn_likelihoods`).
    """

    # Continuation
    C = 0
    # Silence
    NA = 1
    # Interruption
    IN = 2
    # Backchannel
    BC = 3
    # Turn Change
    T = 4


class LabelThreshold(enum.Enum):
    """Per-label likelihood thresholds used to assign hard labels.

    A label is assigned when its likelihood exceeds its threshold; labels
    are tested in priority order C, NA, IN, BC, T (see `assign_hard_label`).
    Thresholds are tuned on the in-domain validation set.
    """

    # Continuation
    C = 0.2
    # Silence
    NA = 0.45
    # Interruption
    IN = 0.4
    # Backchannel
    BC = 0.4
    # Turn Change
    T = 0.4


class TurnLabel(enum.Enum):
    """Speaker-turn labels, one per chunk (see `compute_turn_decisions`)."""

    # AI is speaking
    AI = "A"
    # Human is speaking
    HUM = "B"
    # Overlapping speech where AI was interrupted by Human
    AI_HUM = "AB"
    # Overlapping speech where Human was interrupted by AI
    HUM_AI = "BA"
    # Silence at beginning i.e. no one has started speaking
    NA = "NA"


class MetricThreshold(enum.Enum):
    """Decision thresholds for the proposed pairwise turn-taking metrics.

    Each value is compared against a likelihood (or a difference of two
    likelihoods) to derive the judge label. Thresholds are tuned on the
    in-domain validation set.
    """

    # Turn change vs Continuation: judge "T" when P(T) - P(C) > 0
    turn_change = 0
    # Backchannel vs No BackChannel: judge "BC" when P(BC) > 0.1
    backchannel = 0.1
    # Interruption vs Continuation: judge "IN" when P(IN) - P(C) > -0.45
    interrupt = -0.45
    # Successful vs Unsuccessful Interruption: judge "T" when P(T) - P(C) > -0.1
    success_interrupt = -0.1


class ModelParam(
    enum.Enum
):
    """Parameters of the turn-taking model (configure per model)."""

    # The model emits its first prediction only after this many seconds.
    min_start_time = 0.2
    # The model emits one prediction every this many seconds.
    chunk_length = 0.04

def compute_turn_likelihoods(ref_arr, min_start_time, chunk_length):
    """
    Build a per-file dictionary of turn-taking likelihoods from the
    predictions emitted by the turn-taking model.

    Args:
        ref_arr (array): Lines of the likelihood file. Each line starts with
            an audio file name followed by space-separated likelihood arrays,
            one per chunk; each array is comma-separated and stores the
            likelihood of label "L" at LabelIndex."L".value:
              file_id [likelihoods of chunk 1] [likelihoods of chunk 2] ...
            Example:
              moshi_audio_1.wav 0.52,0.25,0.03,0.04,0.16 0.20,0.62,0.02,0.02,0.14
        min_start_time (float): Model parameter (see ModelParam).
        chunk_length (float): Model parameter (see ModelParam).

    Returns:
        true_dict: {file_id: {chunk end time (s): likelihood array}}.
            End times are rounded to 2 decimals; only the first line seen for
            a given file id is kept. A leading "sw0" prefix is stripped from
            the file id.
    """
    true_dict = {}
    for raw_line in ref_arr:
        fields = raw_line.strip().split()
        file_id = fields[0].replace("sw0", "")
        if file_id in true_dict:
            # Keep only the first occurrence of a file id.
            continue
        chunk_map = {}
        for idx, chunk_probs in enumerate(fields[1:]):
            # End time of chunk idx, rounded to 2 decimals via formatting.
            end_time = float(f"{(min_start_time + (idx + 1) * chunk_length):.2f}")
            chunk_map[end_time] = [float(p) for p in chunk_probs.split(",")]
        true_dict[file_id] = chunk_map
    return true_dict


def compute_turn_decisions(hyp_arr):
    """
    Build per-chunk dictionaries of turn-taking decisions and speaker turns
    for a Human-AI conversation.

    Args:
        hyp_arr (array): Lines of the decision file. Each line covers one
            "ModelParam.chunk_length"-second chunk, comma-separated as:
              file_id,[start time],[end time],[Turn Taking Event],[Speaker Turn]
            Example:
              moshi_audio_1,0.96,1.0,C,A

    Returns:
        pred_dict: {file_id: {chunk end time: turn-taking event label}}
            (labels as in LabelIndex; a raw "I" tag is normalized to "IN").
        turn_dict: {file_id: {chunk end time: speaker turn label}}
            (labels as in TurnLabel).
    """
    pred_dict = {}
    turn_dict = {}
    for raw_line in hyp_arr:
        fields = raw_line.strip().split(",")
        file_id = fields[0]
        if file_id not in pred_dict:
            pred_dict[file_id] = {}
            turn_dict[file_id] = {}
        # Normalize the single-letter interruption tag in place.
        if fields[3] == "I":
            fields[3] = "IN"
        end_time = float(fields[2])
        pred_dict[file_id][end_time] = fields[3]
        turn_dict[file_id][end_time] = fields[-1]
    return pred_dict, turn_dict


def assign_hard_label(likelihood_arr):
    """
    Convert a likelihood array into a single hard turn-taking label using
    thresholds tuned on the validation set (see LabelThreshold).

    Labels are tested in a fixed priority order; the first label whose
    likelihood exceeds its threshold wins. If none qualifies, fall back to
    the label with the highest likelihood.

    Args:
        likelihood_arr (Array[float]): likelihood of label "L" stored at
            LabelIndex."L".value.

    Returns:
        str: one of the label names in LabelIndex.
    """
    # Priority order matters: e.g. C (threshold 0.2) is checked before NA.
    for label in ("C", "NA", "IN", "BC", "T"):
        if likelihood_arr[LabelIndex[label].value] > LabelThreshold[label].value:
            return label
    return LabelIndex(np.argmax(likelihood_arr)).name


class ScoreResult:
    """Aligns turn-taking model pseudo-labels with human/AI decisions and
    computes the evaluation metrics (macro F1, ROC-AUC, confusion matrix,
    and the proposed pairwise turn-taking metrics).

    NOTE(review): the original class carried a ``@dataclasses.dataclass``
    decorator together with a hand-written ``__init__`` and no declared
    fields; its only effect was a field-less ``__eq__`` that made *all*
    instances compare equal, so it has been removed.
    """

    def __init__(
        self,
        true_dict: dict,
        pred_dict: dict,
        turn_dict: dict,
        labels,
        only_AI: bool = False,
        only_human: bool = False,
        human_human: bool = False,
    ) -> None:
        """
        Args:
            true_dict: Likelihood dict (see `compute_turn_likelihoods`).
            pred_dict: Decision dict (see `compute_turn_decisions`).
            turn_dict: Speaker-turn dict (see `compute_turn_decisions`).
            labels: Iterable of label names (see LabelIndex) to score.
            only_AI: Score only decisions made by the AI, i.e. chunks where
                the human holds the turn.
            only_human: Score only decisions made by the human, i.e. chunks
                where the AI holds the turn.
            human_human: True when evaluating the judge model on human-human
                conversation; the two dict arguments are passed to the
                aligner in swapped order and the attributes are named from
                the judge's perspective.
                NOTE(review): in this mode the aligner applies
                `assign_hard_label` to the values of `pred_dict`, so the
                caller appears expected to pass the likelihoods as
                `pred_dict` — confirm against call sites.
        """
        self.labels = labels
        self.only_AI = only_AI  # True if metrics computed only for decisions made by AI
        self.only_human = only_human  # True if computed only for human decisions
        self.human_human = human_human  # True when evaluating the judge model
        if self.human_human:
            (
                self.true_arr,
                self.pred_arr_hard_label,
                self.pred_arr_soft_label,
                self.turn_arr,
            ) = self.combined_pred_hyp_arr(pred_dict, true_dict, turn_dict)
        else:
            (
                self.pred_arr,
                self.true_arr_hard_label,
                self.true_arr_soft_label,
                self.turn_arr,
            ) = self.combined_pred_hyp_arr(true_dict, pred_dict, turn_dict)

    def combined_pred_hyp_arr(self, true_dict, pred_dict, turn_dict):
        """
        Align model pseudo-labels with human/AI decisions chunk by chunk.

        Files and chunk end-times are iterated in `true_dict` order; chunks
        missing from `pred_dict` are printed and skipped, so index i refers
        to the same chunk in every returned array.

        Args:
            true_dict (dict): file -> {chunk end time -> likelihood array}.
            pred_dict (dict): file -> {chunk end time -> decision label}.
            turn_dict (dict): file -> {chunk end time -> speaker turn label}.

        Returns:
            pred_arr: turn-taking decisions per chunk.
            true_arr_hard_label: hard pseudo-labels per chunk.
            true_arr_soft_label: likelihood arrays per chunk.
            turn_arr: speaker-turn labels per chunk.
        """
        true_arr_hard_label = []
        true_arr_soft_label = []
        pred_arr = []
        turn_arr = []
        for file_id in true_dict:
            for end_time in true_dict[file_id]:
                if end_time not in pred_dict[file_id]:
                    # Diagnostic: chunk present in the likelihood file but
                    # absent from the decision file.
                    print(end_time)
                    continue
                true_arr_hard_label.append(
                    assign_hard_label(true_dict[file_id][end_time])
                )
                true_arr_soft_label.append(true_dict[file_id][end_time])
                pred_arr.append(pred_dict[file_id][end_time])
                turn_arr.append(turn_dict[file_id][end_time])

        pred_arr = np.array(pred_arr)
        true_arr_hard_label = np.array(true_arr_hard_label)
        true_arr_soft_label = np.array(true_arr_soft_label)
        turn_arr = np.array(turn_arr)

        # Optionally restrict scoring to one side's decisions. The mask is
        # computed once (the original rebuilt it three times) and applied to
        # ALL arrays, including turn_arr: the original left turn_arr
        # unfiltered, which misaligned it with pred_arr inside the per-chunk
        # metrics below whenever only_AI/only_human was set.
        mask = None
        if self.only_AI:
            # AI decides while the human holds the turn.
            mask = np.logical_or(
                turn_arr == TurnLabel.HUM.value,
                turn_arr == TurnLabel.HUM_AI.value,
            )
        elif self.only_human:
            # Human decides while the AI holds the turn.
            mask = np.logical_or(
                turn_arr == TurnLabel.AI.value,
                turn_arr == TurnLabel.AI_HUM.value,
            )
        if mask is not None:
            pred_arr = pred_arr[mask]
            true_arr_hard_label = true_arr_hard_label[mask]
            true_arr_soft_label = true_arr_soft_label[mask]
            turn_arr = turn_arr[mask]
        return pred_arr, true_arr_hard_label, true_arr_soft_label, turn_arr

    def _label_accuracy(self, gt_arr, pred_arr, label):
        """Accuracy (%, 1 decimal) over chunks whose decision equals `label`.

        Returns None when no such chunk exists; the original crashed inside
        sklearn on empty selections for every metric except
        handle_interruption_metric.
        """
        mask = np.asarray(pred_arr) == label
        if not np.any(mask):
            return None
        return round(accuracy_score(gt_arr[mask], pred_arr[mask]) * 100, 1)

    def _pause_decision_pairs(self, speaker, threshold):
        """Collect (judge, decision) label pairs at pause points.

        A pause point is a chunk where `speaker` holds the turn and the
        previous 3 decisions were silence ("NA"); only chunks whose decision
        is "C" or "T" are kept. The judge label is "T" when
        (turn-change likelihood - continuation likelihood) > threshold,
        else "C".

        Returns:
            gt_arr, pred_arr: aligned numpy arrays of judge/decision labels.
        """
        decision_labels = []
        judge_labels = []
        for k, decision in enumerate(self.pred_arr):
            if k < 3 or self.turn_arr[k] != speaker:
                continue
            if not (self.pred_arr[k - 3 : k] == "NA").all():
                continue
            if decision not in ("C", "T"):
                continue
            decision_labels.append(decision)
            diff_val = (
                self.true_arr_soft_label[k][LabelIndex.T.value]
                - self.true_arr_soft_label[k][LabelIndex.C.value]
            )
            judge_labels.append("T" if diff_val > threshold else "C")
        return np.array(judge_labels), np.array(decision_labels)

    def compute_F1(self):
        """
        Computes one-vs-rest macro F1 of the turn-taking decisions against
        `hard` pseudo labels generated from the turn-taking model.

        Returns:
            F1_dict: Dict with keys as label and values as macro F1 for that label
        """
        # Pick which side is gold vs prediction once instead of branching
        # inside the loop.
        if self.human_human:
            gold, pred = self.true_arr, self.pred_arr_hard_label
        else:
            gold, pred = self.true_arr_hard_label, self.pred_arr
        score = 0
        F1_dict = {}
        for k in self.labels:
            print("Label: " + k)
            x = f1_score(gold == k, pred == k, average="macro")
            F1_dict[k] = round(x, 3)
            print(classification_report(gold == k, pred == k))
            print("Macro F1: " + str(round(x, 3)))
            print("--------")
            score += x
        print("Overall F1: " + str(round(score / len(self.labels), 3)))
        return F1_dict

    def compute_roc_auc(self):
        """
        Computes one-vs-rest ROC-AUC per label, scoring each label by its
        soft likelihood. Only valid in human-human (judge-model) mode.

        Returns:
            auc_dict: Dict with keys as label and values as ROC AUC.

        Raises:
            KeyError: if a label in `self.labels` is not a LabelIndex member.
        """
        assert self.human_human is True
        score = 0
        auc_dict = {}
        for k in self.labels:
            print("Label: " + k)
            # The original if/elif ladder collapsed to one enum lookup; an
            # unknown label now raises KeyError instead of leaving the score
            # variable unbound (NameError).
            x = roc_auc_score(
                self.true_arr == k,
                self.pred_arr_soft_label[:, LabelIndex[k].value],
            )
            auc_dict[k] = round(x, 3)
            print("ROC AUC: " + str(round(x, 3)))
            print("--------")
            score += x
        print("Overall ROC AUC: " + str(round(score / len(self.labels), 3)))
        return auc_dict

    def compute_confusion_matrix(self):
        """
        Computes the confusion matrix of decisions vs `hard` pseudo labels
        and renders it with matplotlib (blocking `plt.show()`).

        NOTE(review): reads `true_arr_hard_label`/`pred_arr`, which only
        exist when human_human=False — confirm before calling in judge mode.

        Returns:
            cm: 2D confusion matrix array.
        """
        cm = confusion_matrix(
            self.true_arr_hard_label, self.pred_arr, labels=self.labels
        )
        disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=self.labels)
        disp.plot()
        plt.show()
        return cm

    def turn_change_metric(self):
        """
        Conversation capability: when the user speaks, when should the
        system speak up? An ideal AI has higher turn-change likelihood when
        it decides to speak and higher continuation likelihood when it lets
        the user continue. Judge label = "T" if
        (turn-change likelihood - continuation likelihood
         > MetricThreshold.turn_change.value) else "C".

        Returns:
            accuracy_pause: Accuracy when the system lets the user continue
                (None if no such events).
            accuracy_turn_change: Accuracy when the system decides to speak
                up (None if no such events).
        """
        gt_arr, pred_arr = self._pause_decision_pairs(
            TurnLabel.HUM.value, MetricThreshold.turn_change.value
        )
        accuracy_pause = self._label_accuracy(gt_arr, pred_arr, "C")  # AI lets human continue
        print("Accuracy during Pause: " + str(accuracy_pause))
        accuracy_turn_change = self._label_accuracy(gt_arr, pred_arr, "T")  # AI speaks up
        print("Accuracy during Turn Change: " + str(accuracy_turn_change))
        return (
            accuracy_pause,
            accuracy_turn_change,
        )

    def make_backchannel_metric(self):
        """
        Conversation capability: when the user speaks, when should the
        system backchannel? For each chunk where the human holds the turn
        and the previous decision was not already "BC", the decision is
        "BC" or the dummy "X" (no backchannel); the judge label is "BC"
        when the backchannel likelihood exceeds
        MetricThreshold.backchannel.value, else "X".

        Returns:
            accuracy_backchannel: Accuracy when the system decides to
                produce a backchannel (None if no such events).
            accuracy_no_backchannel: Accuracy when it decides not to
                (None if no such events).
        """
        decision_labels = []
        judge_labels = []
        for k, decision in enumerate(self.pred_arr):
            if k < 1 or self.turn_arr[k] != TurnLabel.HUM.value:
                continue
            if self.pred_arr[k - 1] == "BC":  # AI is already backchanneling
                continue
            decision_labels.append(decision if decision == "BC" else "X")
            is_bc = (
                self.true_arr_soft_label[k][LabelIndex.BC.value]
                > MetricThreshold.backchannel.value
            )
            judge_labels.append("BC" if is_bc else "X")
        gt_arr = np.array(judge_labels)
        pred_arr = np.array(decision_labels)
        accuracy_backchannel = self._label_accuracy(gt_arr, pred_arr, "BC")
        print("Accuracy during BackChannel: " + str(accuracy_backchannel))
        # Every decision that is not "BC" carries the dummy "X" label.
        accuracy_no_backchannel = self._label_accuracy(gt_arr, pred_arr, "X")
        print("Accuracy during No Backchannel: " + str(accuracy_no_backchannel))
        return (
            accuracy_backchannel,
            accuracy_no_backchannel,
        )

    def make_interruption_metric(self):
        """
        Conversation capability: when the user speaks, when should the
        system interrupt? Considers chunks where the human holds the turn
        and the previous decision was "C". Judge label = "IN" if
        (interruption likelihood - continuation likelihood
         > MetricThreshold.interrupt.value) else "C".

        Returns:
            accuracy_interrupt: Accuracy when the system decides to
                interrupt (None if no such events).
            accuracy_no_interrupt: Accuracy when it lets the user continue
                (None if no such events).
        """
        decision_labels = []
        judge_labels = []
        for k, decision in enumerate(self.pred_arr):
            if k < 1:
                # Bug fix: the original indexed pred_arr[k - 1] at k == 0,
                # which wrapped around to the final chunk.
                continue
            if (
                self.pred_arr[k - 1] != "C"
                or self.turn_arr[k] != TurnLabel.HUM.value
            ):
                continue
            if decision not in ("C", "IN"):
                continue
            decision_labels.append(decision)
            diff_val = (
                self.true_arr_soft_label[k][LabelIndex.IN.value]
                - self.true_arr_soft_label[k][LabelIndex.C.value]
            )
            judge_labels.append(
                "IN" if diff_val > MetricThreshold.interrupt.value else "C"
            )
        gt_arr = np.array(judge_labels)
        pred_arr = np.array(decision_labels)
        accuracy_interrupt = self._label_accuracy(gt_arr, pred_arr, "IN")  # AI interrupts
        print("Accuracy during Interruption: " + str(accuracy_interrupt))
        accuracy_no_interrupt = self._label_accuracy(gt_arr, pred_arr, "C")  # AI lets human speak
        print("Accuracy during No Interruption: " + str(accuracy_no_interrupt))
        return (
            accuracy_interrupt,
            accuracy_no_interrupt,
        )

    def turn_willingness_metric(self):
        """
        Conversation capability: when the system speaks, does it convey when
        the user may speak up? Same pause-point construction as
        `turn_change_metric`, but while the AI holds the turn.

        Returns:
            accuracy_pause: Accuracy when the system continues
                (None if no such events).
            accuracy_turn_change: Accuracy when the user speaks up
                (None if no such events).
        """
        gt_arr, pred_arr = self._pause_decision_pairs(
            TurnLabel.AI.value, MetricThreshold.turn_change.value
        )
        accuracy_pause = self._label_accuracy(gt_arr, pred_arr, "C")  # AI continues
        print("Accuracy during Pause: " + str(accuracy_pause))
        accuracy_turn_change = self._label_accuracy(gt_arr, pred_arr, "T")  # User speaks up
        print("Accuracy during Turn Change: " + str(accuracy_turn_change))
        return (
            accuracy_pause,
            accuracy_turn_change,
        )

    def handle_interruption_metric(self):
        """
        Conversation capability: when the system speaks, does it yield to
        the user's interruptions? Considers overlap chunks (AI interrupted
        by human) whose previous decision was "IN". Judge label = "T"
        (successful interruption) if
        (turn-change likelihood - continuation likelihood
         > MetricThreshold.success_interrupt.value) else "C".

        Returns:
            accuracy_unsuccess: Accuracy when the user cannot successfully
                interrupt (None if no such events).
            accuracy_success: Accuracy when the system lets the user make a
                successful interruption (None if none were made).
        """
        decision_labels = []
        judge_labels = []
        for k, decision in enumerate(self.pred_arr):
            if k < 1:
                continue
            if (
                self.turn_arr[k] != TurnLabel.AI_HUM.value
                or self.pred_arr[k - 1] != "IN"
            ):  # Keep only chunks where AI is being interrupted by the human
                continue
            if decision not in ("C", "T"):
                continue
            decision_labels.append(decision)
            diff_val = (
                self.true_arr_soft_label[k][LabelIndex.T.value]
                - self.true_arr_soft_label[k][LabelIndex.C.value]
            )
            judge_labels.append(
                "T"
                if diff_val > MetricThreshold.success_interrupt.value
                else "C"
            )
        gt_arr = np.array(judge_labels)
        pred_arr = np.array(decision_labels)
        accuracy_unsuccess = self._label_accuracy(gt_arr, pred_arr, "C")  # AI continues
        print("Accuracy during Unsuccessful Interruption: " + str(accuracy_unsuccess))
        accuracy_success = self._label_accuracy(gt_arr, pred_arr, "T")
        if accuracy_success is None:
            print("No successful interruptions were made")
        else:
            # Human successfully takes the turn.
            print("Accuracy during Successful Interruption: " + str(accuracy_success))
        return (
            accuracy_unsuccess,
            accuracy_success,
        )
